# Originally from https://github.com/aliutkus/swf
# No known license; this code has been, and is being, used here for research purposes only.

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Fully connected generator network.

    Architecture (applied in order):
        Linear(in_size -> 256) -> ReLU
        -> Linear(256 -> 512) -> BatchNorm1d(512) -> ReLU
        -> Linear(512 -> 256) -> ReLU
        -> Linear(256 -> output_shape)

    Args:
        in_size: number of input features per sample.
        output_shape: number of output features per sample. Despite its
            name, it is passed directly to ``nn.Linear`` as
            ``out_features``, so it must be an int for this class.
    """

    def __init__(self, in_size, output_shape):
        super().__init__()
        # Stored on the instance; forward() itself does not read it.
        self.output_shape = output_shape

        # The order of this list fixes both the ``layers.<idx>`` state_dict
        # keys and the order in which parameters are randomly initialised,
        # so it must not be rearranged.
        stack = [
            nn.Linear(in_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, output_shape),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result.

        NOTE(review): presumably ``x`` is (batch, in_size). BatchNorm1d
        sees the 512 hidden features on dim 1, and in training mode it
        needs more than one sample per batch.
        """
        result = self.layers(x)
        return result
    #     self.layers = nn.Sequential(
    #         nn.Linear(in_size, 128),
    #         nn.ReLU(),
    #         nn.Linear(128, 258),
    #         nn.ReLU(),
    #         nn.Linear(258, 512),
    #         nn.ReLU(),
    #         nn.Linear(512, 128),
    #         nn.ReLU(),
    #         nn.Linear(128, output_shape)
    #     )

    # def forward(self, x):
    #     return self.layers(x)

# class Generator(nn.Module):
#     def __init__(self, in_size, output_shape):
#         super().__init__()
#         self.output_shape = output_shape
#         d = output_shape[-1]
#         self.fc4 = nn.Linear(in_size, int(d / 2 * d / 2 * d))
#         deconv1 = nn.ConvTranspose2d(d, d, 3, stride=1, padding=1)
#         deconv2 = nn.ConvTranspose2d(d, d, 3, stride=1, padding=1)
#         deconv3 = nn.ConvTranspose2d(d, d, 2, stride=2, padding=0)
#         conv5 = nn.Conv2d(d, self.output_shape[0], 3, stride=1, padding=1)
#         relu = nn.ReLU(inplace=True)
#         sigmoid = nn.Sigmoid()
#         self.conv_network = nn.Sequential(deconv1, relu, deconv2, relu, deconv3, relu, conv5, sigmoid)

#     def forward(self, x):
#         d = self.output_shape[-1]
#         out = torch.relu(self.fc4(x))
#         out = out.view(-1, d, int(d / 2), int(d / 2))
#         return self.conv_network(out)


# class Generator(nn.Module):
#     def __init__(self, input_shape, bottleneck_size=64):
#         super(ConvDecoder, self).__init__()
#         self.input_shape = input_shape
#         d = input_shape[-1]

#         self.fc4 = nn.Linear(bottleneck_size, int(d/2 * d/2 * d))
#         self.deconv1 = nn.ConvTranspose2d(d, d,
#                                           kernel_size=3, stride=1, padding=1)
#         self.deconv2 = nn.ConvTranspose2d(d, d,
#                                           kernel_size=3, stride=1, padding=1)
#         self.deconv3 = nn.ConvTranspose2d(d, d,
#                                           kernel_size=2, stride=2, padding=0)
#         self.conv5 = nn.Conv2d(d, self.input_shape[0],
#                                kernel_size=3, stride=1, padding=1)

#     def forward(self, x):
#         d = self.input_shape[-1]
#         out = torch.relu(self.fc4(x))
#         out = out.view(-1, d, int(d/2), int(d/2))
#         out = torch.relu(self.deconv1(out))
#         out = torch.relu(self.deconv2(out))
#         out = torch.relu(self.deconv3(out))
#         return torch.sigmoid(self.conv5(out))